import math
from pathlib import Path

import numpy as np
import pandas as pd
from linearmodels.iv import IV2SLS


# Numeric sentinel codes that clean_numeric() replaces with NaN by default.
# NOTE(review): applying the full set to a variable whose legitimate range
# includes the short codes (7/8/9 or 77/88/99 -- e.g. ages or years of
# education) blanks real values; confirm the appropriate subset per variable.
MISSING_CODES = {7, 8, 9, 77, 88, 99, 777, 888, 999, 7777, 8888, 9999}


def clean_numeric(series: pd.Series, missing_codes: "set[int] | None" = None) -> pd.Series:
    """Coerce ``series`` to numeric and replace sentinel codes with NaN.

    Parameters
    ----------
    series
        Raw column; values that cannot be parsed as numbers become NaN.
    missing_codes
        Values to treat as missing. When ``None`` (the default) the
        module-level ``MISSING_CODES`` set is used, which preserves the
        previous behaviour. Pass a narrower set for variables whose valid
        range overlaps the short codes (e.g. 7/8/9).

    Returns
    -------
    pd.Series
        Numeric series with unparseable values and sentinel codes set to NaN.
    """
    codes = MISSING_CODES if missing_codes is None else missing_codes
    s = pd.to_numeric(series, errors="coerce")
    # list() so any iterable of codes (set, tuple, frozenset, ...) is accepted.
    return s.mask(s.isin(list(codes)))


def recode_yes_no(series: pd.Series) -> pd.Series:
    """Turn a 1/2 coded item into a 1.0/0.0 indicator.

    The column is first passed through ``clean_numeric``; afterwards 1 maps to
    1.0, 2 maps to 0.0, and every other value becomes NaN.
    """
    indicator_by_code = {1: 1.0, 2: 0.0}
    return clean_numeric(series).map(indicator_by_code)


def zscore(series: pd.Series) -> pd.Series:
    """Standardise ``series`` to mean 0 and population standard deviation 1.

    Values are coerced to numeric first (unparseable entries become NaN). If
    the standard deviation is undefined or not positive, the coerced values
    multiplied by zero are returned instead, so NaNs stay NaN.
    """
    values = pd.to_numeric(series, errors="coerce")
    spread = values.std(ddof=0)
    if pd.notna(spread) and spread > 0:
        return (values - values.mean()) / spread
    # Degenerate case: no variation (or nothing to measure).
    return values * 0


def _mask_codes(series: pd.Series, codes: set) -> pd.Series:
    """Coerce ``series`` to numeric and set the given sentinel ``codes`` to NaN."""
    s = pd.to_numeric(series, errors="coerce")
    return s.mask(s.isin(list(codes)))


def build_dataset(root: Path) -> pd.DataFrame:
    """Load the survey extract, attach region-year instruments, and derive model variables.

    Reads ``outputs/ess_r8_r11_min.parquet`` and the regional broadband CSV
    under ``data_external`` (both relative to ``root``), left-joins the
    regional data on region and survey year, and adds:

    * ``z_broadband_pc_hh`` / ``z_internet_access_pc_hh``: z-scored instruments
    * ``y_*``: 1.0/0.0 outcome indicators and ``y_part_index`` (row mean of the
      participation items that are present)
    * cleaned controls, ``edu_high`` (1.0 if ``eduyrs`` > 12, 0.0 otherwise,
      NaN if ``eduyrs`` is missing) and the fixed-effect key ``ct``

    Rows are kept only if ``regunit`` is 1 or 2, ``region`` is not "99999",
    at least one instrument is available, and ``edu_high`` is defined.

    Parameters
    ----------
    root
        Project root containing the ``outputs`` and ``data_external`` folders.

    Returns
    -------
    pd.DataFrame
        One row per retained respondent.
    """
    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    reg = pd.read_csv(root / "data_external" / "broadband_region_year_eurostat_isoc_r_broad_h.csv")

    ess = ess[ess["regunit"].isin([1, 2])].copy()
    ess["region"] = ess["region"].astype(str).str.upper().str.strip()
    ess = ess[ess["region"] != "99999"].copy()

    reg["region"] = reg["region"].astype(str).str.upper().str.strip()
    reg["year"] = pd.to_numeric(reg["year"], errors="coerce").astype("Int64")
    # .get() yields None for an absent column, which coerces to an all-NaN column.
    reg["broadband_pc_hh"] = pd.to_numeric(reg.get("broadband_pc_hh"), errors="coerce")
    reg["internet_access_pc_hh"] = pd.to_numeric(reg.get("internet_access_pc_hh"), errors="coerce")

    df = ess.merge(reg, left_on=["region", "survey_year"], right_on=["region", "year"], how="left")

    # Endogenous
    df["netusoft"] = clean_numeric(df["netusoft"])

    # Instruments (region-year)
    df["z_broadband_pc_hh"] = zscore(df["broadband_pc_hh"])
    df["z_internet_access_pc_hh"] = zscore(df["internet_access_pc_hh"])

    # Outcomes
    for item in ("vote", "sgnptit", "pbldmn", "bctprd", "contplt", "badge"):
        df[f"y_{item}"] = recode_yes_no(df[item])
    # pstplonl is optional in the extract; fall back to an all-NaN column.
    df["y_pstplonl"] = recode_yes_no(df["pstplonl"]) if "pstplonl" in df.columns else np.nan
    part_cols = ["y_sgnptit", "y_pbldmn", "y_bctprd", "y_contplt", "y_badge", "y_pstplonl"]
    # A row whose items are all NaN already averages to NaN, so no reset is needed.
    df["y_part_index"] = df[part_cols].mean(axis=1, skipna=True)

    # Controls
    # The blanket MISSING_CODES set collides with legitimate values on the
    # wider scales (7-9 years of education, income deciles 7-9, ages 77/88/99),
    # so only the sentinels outside each variable's plausible range are masked.
    # NOTE(review): follows the usual survey convention of placing missing
    # codes at the top of the field width -- confirm against the codebook.
    two_digit_codes = MISSING_CODES - {7, 8, 9}
    three_digit_codes = two_digit_codes - {77, 88, 99}
    df["agea"] = _mask_codes(df["agea"], three_digit_codes)
    df["gndr"] = clean_numeric(df["gndr"])
    df["eduyrs"] = _mask_codes(df["eduyrs"], two_digit_codes)
    df["hinctnta"] = _mask_codes(df["hinctnta"], two_digit_codes)

    # Education indicator: 1.0 above 12 years, 0.0 otherwise, NaN if unknown
    eduyrs = df["eduyrs"]
    df["edu_high"] = (eduyrs > 12).astype(float).where(eduyrs.notna())

    # FE key
    df["ct"] = df["cntry"].astype(str) + "_" + df["survey_year"].astype("Int64").astype(str)

    # Keep rows where at least one instrument exists and edu_high is defined
    has_instrument = df[["z_broadband_pc_hh", "z_internet_access_pc_hh"]].notna().any(axis=1)
    df = df[has_instrument & df["edu_high"].notna()].copy()

    return df


def fit_iv_interaction(df: pd.DataFrame, y: str) -> tuple[object, pd.Series]:
    """
    Fit the education-interaction IV model for the outcome column ``y``.

    Model:
      y = b1*netusoft + b2*(netusoft*edu_high) + controls + ctFE + e
    Instruments:
      z_broadband, z_internet_access, and their interactions with edu_high

    Rows with a missing value in any model column are dropped before
    estimation. Controls are agea, gndr, eduyrs, hinctnta and edu_high; the
    ``ct`` fixed effects enter as dummies (first level dropped) alongside an
    explicit constant. The covariance is clustered by ``region``.

    Parameters
    ----------
    df
        Dataset containing ``y`` and the model columns listed in ``cols``.
    y
        Name of the outcome column.

    Returns
    -------
    tuple
        The fitted result object and the ``region`` labels of the rows used.

    Raises
    ------
    ValueError
        If no row has complete data for every model column.
    """
    cols = [
        y,
        "netusoft",
        "edu_high",
        "z_broadband_pc_hh",
        "z_internet_access_pc_hh",
        "agea",
        "gndr",
        "eduyrs",
        "hinctnta",
        "ct",
        "region",
    ]
    d = df[cols].dropna().copy()
    if d.empty:
        # Fail with a clear message instead of an opaque estimator error.
        raise ValueError(f"No complete observations available for outcome {y!r}")

    d["netusoft_x_edu"] = d["netusoft"] * d["edu_high"]
    d["z_broadband_x_edu"] = d["z_broadband_pc_hh"] * d["edu_high"]
    d["z_inetacc_x_edu"] = d["z_internet_access_pc_hh"] * d["edu_high"]

    # FE via ct dummies. Request float explicitly: newer pandas returns bool
    # dummies, which would make the mixed design matrix convert through an
    # object array.
    d["ct"] = d["ct"].astype(str)
    ct_d = pd.get_dummies(d["ct"], prefix="ct", drop_first=True, dtype=float)

    exog = pd.concat(
        [
            pd.DataFrame({"const": 1.0}, index=d.index),
            d[["agea", "gndr", "eduyrs", "hinctnta", "edu_high"]],
            ct_d,
        ],
        axis=1,
    )
    # Guard against a dummy name colliding with an existing regressor.
    exog = exog.loc[:, ~exog.columns.duplicated()]

    endog = d[["netusoft", "netusoft_x_edu"]]
    instr = d[["z_broadband_pc_hh", "z_internet_access_pc_hh", "z_broadband_x_edu", "z_inetacc_x_edu"]]

    mod = IV2SLS(d[y], exog=exog, endog=endog, instruments=instr)
    res = mod.fit(cov_type="clustered", clusters=d["region"])

    return res, d["region"]


def main() -> None:
    """Build the dataset, fit the IV interaction model per outcome, and write a text report.

    The report is written to ``outputs/iv_interaction_edu.txt`` under the
    directory one level above the folder containing this file; the output
    path is printed once the file has been written.
    """
    project_root = Path(__file__).resolve().parents[1]
    report_path = project_root / "outputs" / "iv_interaction_edu.txt"

    data = build_dataset(project_root)
    survey_years = sorted(int(v) for v in data["survey_year"].dropna().unique())
    country_codes = sorted(data["cntry"].dropna().astype(str).unique().tolist())

    report: list[str] = [
        "IV with education interaction (region-year instruments; ctFE; clustered by region)",
        f"Years in sample: {survey_years}",
        f"Countries in sample: {len(country_codes)}",
        f"Rows (edu_high non-missing): {len(data):,}",
        "",
    ]

    nan = float("nan")
    targets = (
        ("y_part_index", "Participation index"),
        ("y_vote", "Voted"),
        ("y_sgnptit", "Signed petition"),
    )

    for outcome, title in targets:
        report.append(f"== {title} ({outcome}) ==")
        fitted, _ = fit_iv_interaction(data, outcome)
        base = fitted.params.get("netusoft", nan)
        delta = fitted.params.get("netusoft_x_edu", nan)
        base_se = fitted.std_errors.get("netusoft", nan)
        delta_se = fitted.std_errors.get("netusoft_x_edu", nan)

        # Implied effects: low education -> base; high education -> base + delta
        report.extend(
            [
                f"coef(netusoft) [edu_low effect] = {base:.4f} (se={base_se:.4f})",
                f"coef(netusoft_x_edu) [delta high-low] = {delta:.4f} (se={delta_se:.4f})",
                f"implied edu_high effect = {base + delta:.4f}",
            ]
        )

        # First-stage diagnostics, one row per endogenous regressor. Best
        # effort: a failure is recorded in the report rather than aborting.
        try:
            table = fitted.first_stage.diagnostics
            report.append("first_stage diagnostics:")
            for name in table.index:
                row = table.loc[name]
                f_stat = float(row.get("f.stat", nan))
                f_pval = float(row.get("f.pval", nan))
                partial_r2 = float(row.get("partial.rsquared", nan))
                report.append(f"  {name}: partial_R2={partial_r2:.4f} F={f_stat:.2f} p={f_pval:.4g}")
        except Exception as exc:
            report.append(f"first_stage diagnostics: (na) {type(exc).__name__}: {exc}")

        report.append("")

    report_path.write_text("\n".join(report) + "\n", encoding="utf-8")
    print(f"Wrote: {report_path}")


# Generate the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

